#################################################################################
# Replication file for:                                                         #
# "Balancing Precision and Retention in Experimental Design"                    #
#                                                                               #
# Gustavo Diaz                                                                  #
# Northwestern University                                                       #
# gustavo.diaz@northwestern.edu                                                 #
#                                                                               #
# Erin L. Rossiter                                                              #
# University of Notre Dame                                                      #
# erossite@nd.edu                                                               #
#                                                                               #
# This file produces Figure 1 and the correlations reported in Section 4.2.     #
#################################################################################

# Data ----
# Cleaned replication datasets, one per study
# (dh = Dietrich & Hayes, bg = Bayram & Graham, th = Tappin & Hewitt).
df_dh <- readRDS(file = "./data/processed_data/DietrichHayesReplication-clean.rds")
df_bg <- readRDS(file = "./data/processed_data/BayramGrahamReplication-clean.rds")
df_th <- readRDS(file = "./data/processed_data/TappinHewittReplication-clean.rds")

# Correlations reported in Section 4.2 ----
# Pearson correlation between each study's outcome and its pre-treatment
# measure, among respondents assigned to the alternative designs (2 and 3).
#
# Each data frame is filtered by its OWN `design` column: a logical index
# taken from a different data frame would be misaligned (recycled or
# NA-padded to the wrong length).
#
# Args:
#   df       study data frame with a `design` column.
#   post     name of the outcome column.
#   pre      name of the pre-treatment measure column.
#   designs  design codes to keep (default 2 and 3).
# Returns the `htest` object from cor.test().
cor_test_alt_designs <- function(df, post, pre, designs = c(2, 3)) {
  keep <- df[["design"]] %in% designs
  # cor.test() has no `use` argument; it drops incomplete (x, y) pairs
  # itself before computing the correlation.
  result <- cor.test(df[[post]][keep], df[[pre]][keep])
  # Readable label for the printed test instead of df[[post]][keep].
  result$data.name <- paste(post, "and", pre)
  result
}

cor_test_alt_designs(df_dh, "speech_approve_scaled", "dv_quasipre_scaled")
cor_test_alt_designs(df_bg, "delpref_post", "delpref_pre")
cor_test_alt_designs(df_th, "pension_followup_scaled", "pension_pre_scaled")

# Simulation -----
# Prep - create an indicator for sample loss (see 4_descriptives.R)
# dh / bg: loss means the respondent has Finished == 0.
df_dh$any_loss <- df_dh$Finished == 0
df_bg$any_loss <- df_bg$Finished == 0
# th: loss if any of finished_w1..finished_w3 differs from 1, or if
# start_w2 / start_w3 is missing. Reduce() folds `|` left to right, which is
# exactly how the chained expression a | b | c | d | e is evaluated, so NA
# propagation is unchanged.
df_th$any_loss <- Reduce(`|`, list(
  df_th$finished_w1 != 1,
  df_th$finished_w2 != 1,
  df_th$finished_w3 != 1,
  is.na(df_th$start_w2),
  is.na(df_th$start_w3)
))

# Prep - make variable for block id consistent across studies:
# estimation_fun() absorbs `block` fixed effects for design 3, so Tappin &
# Hewitt's `block_id` is copied under that shared name.
df_th <- mutate(df_th, block = block_id)

# Estimate the effect of Z for one simulated draw under an alternative design.
#
# Args:
#   pct_loss    share of the starting sample dropped at random (implicit
#               loss); must be a single number in [0, 1].
#   des         design code: 2 (pre-post, adjusts for `dv_pre`) or
#               3 (block randomized, `block` fixed effects; also adjusts for
#               `dv_pre` when exp == "th").
#   exp_data    study data with columns `design`, `any_loss`, `Z`, `block`
#               (design 3), and the columns named by `dv` / `dv_pre`.
#   dv, dv_pre  names of the outcome and pre-treatment measure columns.
#   starting_n  common sample size drawn before applying `pct_loss`.
#   exp         study label stored in the output (e.g. "dh", "bg", "th").
#
# Returns a one-row data.frame with the study, design, sample size, and the
# coefficient, robust SE, and confidence bounds for Z.
estimation_fun <- function(pct_loss, des, exp_data, dv, dv_pre, starting_n, exp) {
  # Fail early with a clear message instead of "object 'm' not found"
  # (bad `des`) or an obscure sample_n() size error (bad `pct_loss`).
  if (length(des) != 1 || !isTRUE(des %in% c(2, 3))) {
    stop("`des` must be 2 (pre-post) or 3 (block randomized).", call. = FALSE)
  }
  if (!is.numeric(pct_loss) || length(pct_loss) != 1 || is.na(pct_loss) ||
      pct_loss < 0 || pct_loss > 1) {
    stop("`pct_loss` must be a single number in [0, 1].", call. = FALSE)
  }

  # Draw the analysis sample (same RNG calls, in the same order, for both
  # designs):
  #   1. keep units of this design with no explicit loss and non-missing Z;
  #   2. draw `starting_n` of them so every design starts from the same n;
  #   3. keep ceiling(starting_n - starting_n * pct_loss) of those at random.
  # `.env$des` guards against a data column that happens to be named `des`.
  df_subset <- exp_data %>%
    filter(design == .env$des & !any_loss & !is.na(Z)) %>%
    sample_n(starting_n, replace = FALSE) %>%
    sample_n(ceiling(starting_n - starting_n*pct_loss), replace = FALSE)

  # Design 2 always adjusts for the pre-treatment measure; design 3 does so
  # only for Tappin & Hewitt.
  if (des == 2 || exp == "th") {
    formula_str <- paste(dv, "~ Z + ", dv_pre)
  } else {
    formula_str <- paste(dv, "~ Z")
  }

  # df_subset already contains only rows with design == des, so no `subset`
  # argument is needed.
  if (des == 2) {
    m <- estimatr::lm_robust(as.formula(formula_str),
                             data = df_subset)
  } else {
    # Block-randomized design: absorb block fixed effects.
    m <- estimatr::lm_robust(as.formula(formula_str),
                             fixed_effects = ~ block,
                             data = df_subset)
  }

  # store outputs
  data.frame(exp = exp,
             design = des,
             # NOTE(review): `m$n` relies on `$` partial matching against the
             # lm_robust object (presumably its `nobs` element); confirm
             # before renaming.
             n = m$n,
             coef = m$coefficients["Z"],
             se = m$std.error["Z"],
             lower_ci = m$conf.low["Z"],
             upper_ci = m$conf.high["Z"],
             pct_loss = pct_loss)
}

# run simulation ----
# Every combination of implicit loss (0%, 5%, ..., 50%) and alternative design
# (2, 3) is simulated n_iters times per study. All draws come from one RNG
# stream started by set.seed(123) and consumed in the order dh, bg, th, so
# that order must be preserved to reproduce the results.
inputs <- expand.grid(pct_loss = seq(0, .5, .05),
                      des = c(2, 3))
n_iters <- 1000

# Run estimation_fun() `iters` times for each row of `grid` for one study and
# row-bind all draws into a single data frame.
run_alt_design_sims <- function(exp_data, dv, dv_pre, starting_n, exp,
                                grid, iters) {
  map_dfr(seq_len(nrow(grid)), function(i) {
    map_dfr(seq_len(iters), function(iter) {
      estimation_fun(pct_loss = grid$pct_loss[i],
                     des = grid$des[i],
                     exp_data = exp_data,
                     dv = dv,
                     dv_pre = dv_pre,
                     starting_n = starting_n,
                     exp = exp)
    })
  })
}

set.seed(123)
sim_iters_dh <- run_alt_design_sims(exp_data = df_dh,
                                    dv = "speech_approve_scaled",
                                    dv_pre = "dv_quasipre_scaled",
                                    starting_n = 208,
                                    exp = "dh",
                                    grid = inputs,
                                    iters = n_iters)
sim_iters_bg <- run_alt_design_sims(exp_data = df_bg,
                                    dv = "delpref_post",
                                    dv_pre = "delpref_pre",
                                    starting_n = 299,
                                    exp = "bg",
                                    grid = inputs,
                                    iters = n_iters)
sim_iters_th <- run_alt_design_sims(exp_data = df_th,
                                    dv = "pension_followup_scaled",
                                    dv_pre = "pension_pre_scaled",
                                    starting_n = 608,
                                    exp = "th",
                                    grid = inputs,
                                    iters = n_iters)


# Standard design ----
# Simulate n_iters draws at each study's starting sample size to estimate the
# average SE under the standard design (design 1). When the script runs top
# to bottom, these draws continue the RNG stream seeded before the
# alternative-design simulations, in the order dh, bg, th.

# One draw: keep design-1 units with no explicit loss and non-missing Z,
# sample `starting_n` of them without replacement, and regress `dv` on Z
# alone. Returns a one-row data.frame with n, coef, se, and CI bounds for Z.
standard_sim_fun <- function(exp_data, dv, starting_n, exp) {
  df_subset <- exp_data %>%
    filter(design == 1 & !any_loss & !is.na(Z)) %>%
    sample_n(starting_n, replace = FALSE)

  # df_subset already holds only design == 1 rows, so no `subset` is needed.
  m <- estimatr::lm_robust(as.formula(paste(dv, "~ Z")),
                           data = df_subset)

  # store outputs
  data.frame(# NOTE(review): `m$n` relies on `$` partial matching against the
             # lm_robust object (presumably `nobs`); confirm before renaming.
             n = m$n,
             coef = m$coefficients["Z"],
             se = m$std.error["Z"],
             lower_ci = m$conf.low["Z"],
             upper_ci = m$conf.high["Z"],
             exp = exp)
}

standard_sim_dh <- plyr::adply(seq_len(n_iters), 1, function(i) {
  standard_sim_fun(df_dh, dv = "speech_approve_scaled",
                   starting_n = 208, exp = "dh")
})

standard_sim_bg <- plyr::adply(seq_len(n_iters), 1, function(i) {
  standard_sim_fun(df_bg, dv = "delpref_post",
                   starting_n = 299, exp = "bg")
})

standard_sim_th <- plyr::adply(seq_len(n_iters), 1, function(i) {
  standard_sim_fun(df_th, dv = "pension_followup_scaled",
                   starting_n = 608, exp = "th")
})



# Summaries ----
# Stack the per-study simulation draws into one table per design family.
sim_iters_all <- bind_rows(list(sim_iters_dh, sim_iters_bg, sim_iters_th))
standard_sim_all <- bind_rows(
  list(standard_sim_dh, standard_sim_bg, standard_sim_th)
)

# Benchmark for the percent-change comparison: mean SE of the standard
# design, one row per study.
standard_summary <- summarise(
  group_by(standard_sim_all, exp),
  se_mean = mean(se),
  .groups = "drop"
)

# summary for designs 2 and 3 to plot ----
# For each study x implicit-loss level x design: the distribution of the
# simulated SEs, plus the percent change relative to the standard design:
#   [(alt design se - standard se) / standard se] x 100
# The standard-design mean SE is looked up per study with match(); a study
# absent from standard_summary yields NA rather than a case_when size error.
sim_summary <- sim_iters_all %>%
  group_by(exp, pct_loss, design) %>%
  summarise(se_mean = mean(se),
            se_sd = sd(se),
            se_max = max(se),
            se_min = min(se),
            se_975 = quantile(se, .975),
            se_025 = quantile(se, .025),
            .groups = "drop") %>%
  mutate(
    se_standard = standard_summary$se_mean[match(exp, standard_summary$exp)],
    pct_change = ((se_mean - se_standard) / se_standard) * 100,
    pct_change_025 = ((se_025 - se_standard) / se_standard) * 100,
    pct_change_975 = ((se_975 - se_standard) / se_standard) * 100
  ) %>%
  # The lookup column is only an intermediate; drop it so the output columns
  # match the plotting input exactly.
  dplyr::select(-se_standard) %>%
  # clean up values for plot
  mutate(design = factor(design, levels = c(2,3),
                         labels = c("Pre-post", "Block Randomized (and Pre-Post for Tappin & Hewitt)"))) %>%
  mutate(exp = factor(exp, levels = c("dh", "bg", "th"),
                      labels = c("Dietrich & Hayes",
                                 "Bayram & Graham",
                                 "Tappin & Hewitt")))

# Figure 1 ----
# Percent change in SE relative to the standard design across implicit-loss
# levels, one panel per study. Points use the mean simulated SE; ranges span
# the 2.5th-97.5th percentiles of the simulated SEs.
implicit_loss <- ggplot(sim_summary, aes(x = pct_loss,
                                         y = pct_change,
                                         color = factor(design),
                                         shape = factor(design))) +
  facet_wrap(~exp, scales = "fixed", nrow = 1) +
  geom_pointrange(aes(ymin = pct_change_025, ymax = pct_change_975), size = 0.3,
                  position = position_dodge2(width = 0.03)) +
  # Dashed reference line at 0 (no change vs. the standard design). Passing
  # yintercept as a constant draws one line per panel; mapping it with
  # aes(yintercept = 0) would draw a coincident line for every data row.
  geom_hline(yintercept = 0,
             linetype = "dashed",
             color = "black") +
  theme_minimal() +
  theme(
    panel.border = element_rect(color = "black", fill = NA),
    panel.grid.minor.x = element_blank(),
    text = element_text(size = 12),
    axis.title = element_text(size = 12),
    legend.title = element_text(size = 12),
    legend.text = element_text(size = 12),
    strip.text = element_text(size = 12),
    legend.position = "bottom"
  ) +
  labs(x = "Implicit Sample Loss",
       y = "Percentage Change in Standard Error\nRelative to Standard Design",
       color = "Design") +
  scale_color_manual(values = c("grey50", "grey20")) +
  scale_shape_manual(values = c(16,17)) +
  guides(color = guide_legend(title = "Design", override.aes = list(shape = c(16,17)),
                              ncol = 2),
         shape = "none") +
  # Values outside these limits are removed by ggplot2 (with a warning).
  scale_y_continuous(limits = c(-50, 90),
                     breaks = seq(-50, 90, by = 25),
                     labels = paste0(seq(-50, 90, by = 25), "%")) +
  scale_x_continuous(breaks = seq(0, .50, by = .10),
                     labels = paste0(seq(0, 50, by = 10), "%"))

ggsave("figures/fig1.pdf", implicit_loss, width = 10, height = 4)
